# Definition for a binary tree node.
class TreeNode:
    """A binary-tree node carrying a value and optional left/right children."""

    def __init__(self, val=0, left=None, right=None):
        # Keep the payload and both child links exactly as supplied; no
        # validation or copying is performed.
        self.val, self.left, self.right = val, left, right

class Solution:
    """Decide whether one binary tree occurs as a subtree of another.

    A subtree of ``root`` is a node of ``root`` together with all of its
    descendants; a match must agree exactly in both node values and shape.
    Nodes are any objects exposing ``val``, ``left`` and ``right``.
    """

    def isSubtree(
        self, root: "TreeNode | None", subRoot: "TreeNode | None"
    ) -> bool:
        """Return True if ``subRoot`` equals the subtree rooted at some node of ``root``.

        Edge cases:
          * both ``root`` and ``subRoot`` are ``None`` -> True;
          * ``root`` is ``None`` and ``subRoot`` is not -> False (this used to
            raise ``AttributeError``);
          * ``subRoot`` is ``None`` and ``root`` is not -> False, because only
            real nodes of ``root`` are considered as candidates.
        """
        return self.traverse(root, subRoot)

    def traverse(self, root, subRoot):
        """Test every node of ``root`` (pre-order) against ``subRoot``.

        Returns True as soon as one candidate matches. An explicit stack is
        used instead of recursion so that very deep, list-shaped trees do not
        hit Python's recursion limit.
        """
        if root is None:
            # No candidate nodes: only an empty pattern matches an empty tree.
            return self.helper(root, subRoot)

        stack = [root]
        while stack:
            node = stack.pop()
            if self.helper(node, subRoot):
                return True
            # Push right before left so the left child is examined first,
            # preserving the original recursive visiting order.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)
        return False

    def helper(self, root, subRoot):
        """Return True if the trees rooted at ``root`` and ``subRoot`` are identical.

        Two trees are identical when both are empty, or when their roots hold
        equal values and their left and right subtrees are pairwise identical.
        Comparison is iterative and stops at the first mismatch.
        """
        pending = [(root, subRoot)]
        while pending:
            a, b = pending.pop()
            if a is None and b is None:
                # Both branches end here; nothing further to compare.
                continue
            if a is None or b is None or a.val != b.val:
                # Shape mismatch (one side empty) or differing values.
                return False
            pending.append((a.right, b.right))
            pending.append((a.left, b.left))
        return True


if __name__ == '__main__':
    # Demo: the node 4 in `big_tree` has a grandchild 0 under 2 that
    # `pattern` lacks, so no subtree matches exactly.
    big_tree = TreeNode(
        3,
        TreeNode(4, TreeNode(1), TreeNode(2, TreeNode(0))),
        TreeNode(5),
    )
    pattern = TreeNode(4, TreeNode(1), TreeNode(2))

    print(Solution().isSubtree(big_tree, pattern))